#include <unordered_map>
#include <c10/core/impl/alloc_cpu.h>
#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/util/ArrayRef.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <torch/extension.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/ops/abs_native.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/ops/view.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/native/transformers/attention.h>
static uint64_t add_counter = 0;
static uint64_t last_saved_value = 0;
static c10::DeviceIndex custom_device_index = 0;
static uint64_t abs_counter = 0;
static uint64_t last_abs_saved_value = 0;
static uint64_t storageImpl_counter = 0;
static uint64_t last_storageImpl_saved_value = 0;
// register guard
namespace at {
namespace detail {
C10_REGISTER_GUARD_IMPL(
PrivateUse1,
c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
}} // namespace at::detail
namespace {
// This kernel uses the simplest way to obtain contiguous Tensor data and process it.
// It is a demo of the operand API; you can add more complex logic for the input
// and output tensors based on your custom device kernel.
void abs_kernel(at::TensorIteratorBase& iter) {
// Abs has only one input tensor and one output tensor.
auto& output_operand = iter.operand(0);
auto& input_operand = iter.operand(1);
auto& output_tensor_base = output_operand.tensor_base();
auto& input_tensor_base = input_operand.tensor_base();
TORCH_CHECK(!input_operand.original_tensor_base().defined(),
"input original tensor is defined.");
TORCH_CHECK(!output_operand.original_tensor_base().defined(),
"output original tensor is defined.");
// For ease of testing, only accept a contiguous input tensor.
auto memory_format = input_tensor_base.suggest_memory_format();
TORCH_CHECK(input_tensor_base.is_contiguous(memory_format),
"Input tensor must be contiguous.");
// Add necessary restrictions to keep the demo safe.
TORCH_CHECK(input_tensor_base.sizes() == output_tensor_base.sizes(),
"Input and output tensor sizes are not equal.");
// The common dtype is computed by TensorIteratorBase.
TORCH_CHECK(iter.common_dtype() == at::ScalarType::Float,
"Only the float type is supported.");
// Compute abs with a simple for loop.
auto abs_function = [](float* output_ptr, const float* input_ptr,
const int64_t NUM) {
for (int64_t i = 0; i < NUM; ++i) {
*(output_ptr + i) = std::abs(*(input_ptr + i));
}
};
// To simplify the test demo code, we only compute on contiguous tensors on the
// device side, using the input tensor's memory format.
if (iter.is_contiguous()) {
// Check the will_resize flag. You can convert the tensor to a different
// memory format when will_resize is true.
// If the TensorIteratorConfig resize_outputs_ flag is true, there are two
// situations:
// 1) The out tensor is undefined, and TensorIterator sets will_resize to true;
// 2) The out tensor is defined and its size is not equal to the input tensor's
// size; TensorIterator sets will_resize to true and calls set_output_raw_strided
// to resize the output tensor.
// When the output operand's will_resize flag is true, the dummy device can
// convert the tensor to the dummy device's preferred memory format.
// We don't convert the memory format here, because it becomes complex when the
// dummy device wants to keep the same memory format while training a network.
TORCH_CHECK(output_operand.will_resize,
"output operand will_resize flag must be true.");
abs_function((float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel());
} else {
// Strided copy is not supported on the foo device, so use the CPU device instead.
// For the abs op, the remaining case is: the output tensor is not contiguous and
// the operand's will_resize flag is false.
TORCH_CHECK(!output_operand.will_resize, "output operand will_resize is True.");
// Get a contiguous tensor with input memory format.
at::Tensor output = at::empty(output_tensor_base.sizes(),
input_tensor_base.options()
.memory_format(memory_format));
// For a structured op that inherits from TensorIteratorBase, you may need to
// call set_output_raw_strided to update the output stored in the structured op.
// The abs op does not need this.
output_operand.exchange_tensor(c10::MaybeOwned<at::TensorBase>::owned(std::in_place, output));
abs_function((float*)output_operand.tensor_base().mutable_data_ptr(),
(float*)iter.data_ptr(1), iter.numel());
// Copy back to the original tensor base, keeping the same scalar type and
// strides as the CPU and GPU implementations do.
if (output_operand.original_tensor_base().defined() &&
!output_operand.original_tensor_base().is_same(output_operand.tensor_base())) {
output_operand.original_tensor().copy_(output_operand.tensor());
output_operand.restore_original_tensor();
}
}
}
void quantize_tensor_per_tensor_affine_privateuse1(
const at::Tensor& rtensor,
at::Tensor& qtensor,
double scale,
int64_t zero_point) {
// do nothing
}
int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value,
const std::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
auto backend = sdp::SDPBackend::overrideable;
return static_cast<int64_t>(backend);
}
} // namespace
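// Register the kernels defined above with ATen's DispatchStub mechanism, so that
// stub-based ops (abs, quantize_per_tensor, and the fused SDPA backend choice)
// are routed to these implementations when run on the PrivateUse1 backend.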
namespace at::native {
REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel);
REGISTER_PRIVATEUSE1_DISPATCH(quantize_tensor_per_tensor_affine_stub, &quantize_tensor_per_tensor_affine_privateuse1);
REGISTER_PRIVATEUSE1_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_privateuse1);
} // namespace at::native
struct CustomBackendMetadata : public c10::BackendMeta {
// For testing: this field will mutate when clone() is called by shallow_copy_from.
int backend_version_format_{-1};
int format_number_{-1};
mutable bool cloned_{false};
// define the constructor
CustomBackendMetadata(int backend_version_format, int format_number) :
backend_version_format_(backend_version_format), format_number_(format_number) {}
c10::intrusive_ptr<c10::BackendMeta> clone(
const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
cloned_ = true;
return c10::BackendMeta::clone(ptr);
}
};
// we need to register two functions for serialization
void for_serialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr() == nullptr) {
return;
}
auto tmeta = dynamic_cast<CustomBackendMetadata*>(t.unsafeGetTensorImpl()->get_backend_meta());
if (tmeta->backend_version_format_ == 1) {
m["backend_version_format"] = true;
}
if (tmeta->format_number_ == 29) {
m["format_number"] = true;
}
}
void for_deserialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
int backend_version_format{-1};
int format_number{-1};
if (m.find("backend_version_format") != m.end()) {
backend_version_format = 1;
}
if (m.find("format_number") != m.end()) {
format_number = 29;
}
c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(
new CustomBackendMetadata(backend_version_format, format_number))};
t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
}
void custom_serialization_registry() {
torch::jit::TensorBackendMetaRegistry(c10::DeviceType::PrivateUse1,
&for_serialization,
&for_deserialization);
}
// Check whether BackendMeta serialization works correctly.
bool check_backend_meta(const at::Tensor& t) {
if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr()) {
CustomBackendMetadata* tmeta = dynamic_cast<CustomBackendMetadata*>(
t.unsafeGetTensorImpl()->get_backend_meta());
if (tmeta->backend_version_format_==1 && tmeta->format_number_==29) {
return true;
}
}
return false;
}
// a fake set function is exposed to the Python side
void custom_set_backend_meta(const at::Tensor& t) {
int backend_version_format{1};
int format_number{29};
c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(
new CustomBackendMetadata(backend_version_format, format_number))};
t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
}
// A dummy StorageImpl for our custom device that secretly uses the CPU
c10::intrusive_ptr<c10::StorageImpl> make_custom_storage_impl(c10::StorageImpl::use_byte_size_t,
c10::SymInt size_bytes,
c10::DataPtr data_ptr,
c10::Allocator* allocator,
bool resizable) {
c10::intrusive_ptr<c10::StorageImpl> custom_storage_impl;
if (data_ptr == nullptr){
custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
c10::StorageImpl::use_byte_size_t(), size_bytes, allocator, resizable);
} else {
custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
c10::StorageImpl::use_byte_size_t(), size_bytes, std::move(data_ptr), allocator, resizable);
}
storageImpl_counter += 1;
return custom_storage_impl;
}
// Register our dummy StorageImpl creation method.
void custom_storage_registry() {
c10::SetStorageImplCreate(c10::DeviceType::PrivateUse1, &make_custom_storage_impl);
}
bool custom_storageImpl_called() {
if (storageImpl_counter > last_storageImpl_saved_value) {
last_storageImpl_saved_value = storageImpl_counter;
return true;
}
return false;
}
// basic dummy add function
at::Tensor custom_add_Tensor(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {
add_counter += 1;
// Since this custom device is just for testing, not bothering to implement kernels.
return at::empty(self.sizes(), self.options());
}
// basic abs function
at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) {
return at::native::abs_out(self, out);
}
// A dummy allocator for our custom device, that secretly uses the CPU
struct DummyCustomAllocator final : at::Allocator {
DummyCustomAllocator() = default;
at::DataPtr allocate(size_t nbytes) override {
void* data = c10::alloc_cpu(nbytes);
return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, custom_device_index)};
}
static void ReportAndDelete(void* ptr) {
if (!ptr) {
return;
}
c10::free_cpu(ptr);
}
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
default_copy_data(dest, src, count);
}
};
// Register our dummy allocator
static DummyCustomAllocator global_custom_alloc;
REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc);
// basic dummy empty function, so we can directly construct tensors on the custom device
// This dummy test device will just use the CPU allocator, and ignores pinned memory.
at::Tensor custom_empty_memory_format(at::IntArrayRef size,
std::optional<at::ScalarType> dtype,
std::optional<at::Layout> layout,
std::optional<at::Device> device,
std::optional<bool> pin_memory,
std::optional<at::MemoryFormat> memory_format) {
constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
return at::detail::empty_generic(size,
&global_custom_alloc,
private_use_ks,
c10::dtype_or_default(dtype),
memory_format);
}
at::Tensor custom_empty_symint(c10::IntArrayRef size,
std::optional<at::ScalarType> dtype,
std::optional<at::Layout> layout,
std::optional<at::Device> device,
std::optional<bool> pin_memory,
std::optional<at::MemoryFormat> memory_format) {
constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
return at::detail::empty_generic(size,
&global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
}
at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) {
// Not bothering to implement.
return self;
}
// Unsafely create a CPU tensor that shares the dummy device tensor's data_ptr.
at::Tensor unsafe_create_cpu_tensor_from_dummy_tensor(const at::Tensor& src) {
TORCH_CHECK(src.device().type() == c10::DeviceType::PrivateUse1,
"Only support dummy device.");
const auto& sizes_ = src.sizes();
const auto& strides_ = src.strides();
auto storage_offset_ = src.storage_offset();
at::detail::check_size_nonnegative(sizes_);
size_t size_bytes = at::detail::computeStorageNbytes(sizes_, strides_,
src.element_size(),
storage_offset_);
at::DataPtr data_ptr =
c10::InefficientStdFunctionContext::makeDataPtr(src.storage().mutable_data_ptr().get(),
[](void*){}, at::kCPU);
c10::Storage storage{c10::Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr),
/*allocator=*/&global_custom_alloc, /*resizable=*/false};
constexpr c10::DispatchKeySet cpu_ks(c10::DispatchKey::CPU);
at::Tensor tensor = at::detail::make_tensor<c10::TensorImpl>(
std::move(storage), cpu_ks, src.dtype());
c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
tensor_impl->set_sizes_and_strides(sizes_, strides_);
tensor_impl->set_storage_offset(storage_offset_);
return tensor;
}
// basic dummy copy_() function, so we can copy between the custom device and the CPU
at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
TORCH_CHECK(
self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1,
"Dummy test only allows copies between the cpu and the dummy device.");
TORCH_CHECK(
dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1,
"Dummy test only allows copies between the cpu and the dummy device.");
// Some dummy asserts for the basic use case: inputs have the same size and dtype.
TORCH_CHECK(self.sizes() == dst.sizes());
TORCH_CHECK(self.scalar_type() == dst.scalar_type());
if (self.is_contiguous() && dst.is_contiguous()) {
std::memcpy(dst.storage().data_ptr().get(),
self.storage().data_ptr().get(),
self.storage().nbytes());
} else {
// Use CPU tensors to perform the strided copy.
auto convert_to_cpu_tensor = [](const at::Tensor& src) -> at::Tensor {
if (src.device().type() == c10::DeviceType::PrivateUse1) {
return unsafe_create_cpu_tensor_from_dummy_tensor(src);
} else {
return src;
}
};
at::Tensor cpu_self = convert_to_cpu_tensor(self);
at::Tensor cpu_dst = convert_to_cpu_tensor(dst);
cpu_dst.copy_(cpu_self);
}
return dst;
}
at::Tensor custom__copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) {
return custom__copy_from(self, dst, false);
}
at::Tensor custom_empty_strided(c10::IntArrayRef size,
c10::IntArrayRef stride,
std::optional<at::ScalarType> dtype_opt,
std::optional<at::Layout> layout_opt,
std::optional<at::Device> device_opt,
std::optional<bool> pin_memory_opt) {
constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
auto dtype = c10::dtype_or_default(dtype_opt);
return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype);
}
// Some set operations for the basic use case
at::Tensor& custom_set_source_Storage(at::Tensor& result, c10::Storage src) {
int64_t new_size = static_cast<int64_t>(src.nbytes() / result.dtype().itemsize());
c10::IntArrayRef stride = {};
result.unsafeGetTensorImpl()->set_storage_offset(0);
at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
new_size, stride_opt,
/*resize_storage=*/!result.is_meta());
return result;
}
// Some set operations for the basic use case
at::Tensor& custom_set_source_Storage_storage_offset(at::Tensor& result,
c10::Storage storage,
int64_t storage_offset,
c10::IntArrayRef size,
c10::IntArrayRef stride) {
result.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
size, stride_opt,
/*resize_storage=*/!result.is_meta());
return result;
}
const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size,
std::optional<at::MemoryFormat> optional_memory_format) {
at::TensorImpl* tensor_impl = self.unsafeGetTensorImpl();
tensor_impl->set_sizes_contiguous(size);
const auto itemsize = tensor_impl->dtype().itemsize();
const auto offset = tensor_impl->storage_offset();
const auto storage_size = at::detail::computeStorageNbytesContiguous(size, itemsize, offset);
// The dummy device uses the CPU allocator, so just call the CPU function
// maybe_resize_storage_cpu from aten/src/ATen/native/Resize.h to obtain
// sufficient memory.
at::native::maybe_resize_storage_cpu(tensor_impl, storage_size);
if (optional_memory_format.has_value()) {
auto memory_format =
optional_memory_format.value();
TORCH_CHECK(
memory_format != at::MemoryFormat::Preserve,
"Unsupported memory format",
memory_format);
tensor_impl->empty_tensor_restride(memory_format);
}
return self;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt, c10::SymInt, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable(
const at::Tensor & query,
const at::Tensor & key,
const at::Tensor & value,
const std::optional<at::Tensor> & attn_bias,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);
const int64_t max_seqlen_q = query.size(2);
const int64_t max_seqlen_kv = key.size(2);
auto opts = query.options();
auto output = at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts);
auto logsumexp = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
auto debug_attn_mask = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_kv},
opts.dtype(at::kFloat));
auto philox_seed = at::empty({}, at::dtype(at::kLong));
auto philox_offset = at::empty({}, at::dtype(at::kLong));
return std::make_tuple(output, logsumexp, at::Tensor(), at::Tensor(), max_seqlen_q, max_seqlen_kv, philox_seed, philox_offset, debug_attn_mask);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable_backward(
const at::Tensor & grad_out,
const at::Tensor & query,
const at::Tensor & key,
const at::Tensor & value,
const at::Tensor & attn_bias,
std::array<bool,4> grad_input_mask,
const at::Tensor & out,
const at::Tensor & logsumexp,
const at::Tensor & cum_seq_q,
const at::Tensor & cum_seq_k,
int64_t max_q,
int64_t max_k,
double dropout_p,
bool is_causal,
const at::Tensor & philox_seed,
const at::Tensor & philox_offset,
std::optional<double> scale) {
return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(
at::empty_like(query),
at::empty_like(key),
at::empty_like(value),
at::empty_like(attn_bias));
}
// This macro does the heavy lifting.
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
// Later in this file, we map a custom device to the PrivateUse1 device type,
// which allows user code that puts a tensor on your custom_device to eventually get plumbed
// into the kernels registered here.
//
// This macro registers your kernels to the PyTorch Dispatcher.
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("abs.out", &custom_abs_out);
m.impl("add.Tensor", &custom_add_Tensor);
m.impl("empty.memory_format", &custom_empty_symint);
m.impl("fill_.Scalar", &custom_fill__scalar);
m.impl("_copy_from", &custom__copy_from);
m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
m.impl("empty_strided", &custom_empty_strided);
m.impl("set_.source_Storage", &custom_set_source_Storage);
m.impl("set_.source_Storage_storage_offset",&custom_set_source_Storage_storage_offset);
m.impl("resize_", &custom_resize_);
m.impl("as_strided", at::native::as_strided_tensorimpl);
m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward);
}
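// A boxed CPU fallback: at::native::cpu_fallback copies the input tensors to the
// CPU, redispatches the op to its CPU kernel, and copies the results back to the
// custom device. The ops registered below opt into this fallback instead of
// providing a dedicated PrivateUse1 kernel.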
void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
at::native::cpu_fallback(op, stack);
}
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("_fused_adamw_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}
// This basic implementation doesn't bother dealing with different device indices
// (e.g. custom_device:0 vs. custom_device:1).
// We could do that by letting the user pass in a device index in our exposed device function.
// Note that if you do that, you'll also need to register a device guard to core.
// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`.
c10::Device get_custom_device() {
return c10::Device(c10::DeviceType::PrivateUse1, 0);
}
bool custom_add_called() {
bool called = false;
if (add_counter > last_saved_value) {
called = true;
last_saved_value = add_counter;
}
return called;
}
class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
public:
// Constructors
PrivateGeneratorImpl(c10::DeviceIndex device_index) {
device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
}
~PrivateGeneratorImpl() override = default;
};
// This is used to register the generator for the custom device.
at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
return at::make_generator<PrivateGeneratorImpl>(device_index);
}
void register_generator_first() {
REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
}
void register_generator_second() {
REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
}
void set_custom_device_index(c10::DeviceIndex device_index) {
custom_device_index = device_index;
}
// a global flag used for dummy pin_memory of custom device
bool custom_pinned_flag = false;
struct FooHooksArgs : public at::PrivateUse1HooksArgs {};
struct FooHooksInterface : public at::PrivateUse1HooksInterface {
FooHooksInterface(FooHooksArgs) {}
~FooHooksInterface() override = default;
const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override {
static auto device_gen = make_generator_privateuse1(device_index);
return device_gen;
}
// This is a simple implementation: custom_pinned_flag is set to true once
// tensor.pin_memory() is called, and afterwards tensor.is_pinned() always
// returns true no matter which tensor it is called on.
bool isPinnedPtr(const void* data) const override {
return custom_pinned_flag;
}
c10::Allocator* getPinnedMemoryAllocator() const override {
custom_pinned_flag = true;
return c10::GetCPUAllocator();
}
};
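// Hooks registry boilerplate: declare and define a registry keyed by name,
// register FooHooksInterface under "FooHooks", and lazily create a single
// instance in get_private_hooks() below.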
TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs);
C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs)
// Use the Create function to obtain a PrivateUse1HooksInterface pointer from the PrivateUse1HooksRegistry.
C10_REGISTER_TYPED_CLASS(PrivateUse1HooksRegistry, "FooHooks", FooHooksInterface)
static at::PrivateUse1HooksInterface* privateuse1_hooks_local = nullptr;
static at::PrivateUse1HooksInterface* get_private_hooks() {
static c10::once_flag once;
c10::call_once(once, [] {
privateuse1_hooks_local = PrivateUse1HooksRegistry()->Create("FooHooks", {}).release();
if (!privateuse1_hooks_local) {
privateuse1_hooks_local = new FooHooksInterface(FooHooksArgs{});
}
});
return privateuse1_hooks_local;
}
void register_hook() {
at::RegisterPrivateUse1HooksInterface(get_private_hooks());
}
bool is_register_hook() {
return privateuse1_hooks_local != nullptr;
}
const at::Generator& default_generator(c10::DeviceIndex device_index) {
return at::globalContext().defaultGenerator(at::Device(c10::DeviceType::PrivateUse1, device_index));
}
void fallback_with_undefined_tensor() {
at::Tensor first = at::empty({2, 3}).to(at::DeviceType::PrivateUse1);
at::Tensor second = at::Tensor();
at::Tensor step = at::empty({}).fill_(2).to(at::DeviceType::PrivateUse1);
at::Tensor grad_scale = at::empty({}).fill_(0.00001).to(at::DeviceType::PrivateUse1);
at::Tensor found_inf = at::empty({}).fill_(1).to(at::DeviceType::PrivateUse1);
at::TensorList tensors = {first, first};
at::TensorList undefined_tensors = {first, second};
at::TensorList steps = {step, step};
return at::_fused_adamw_(tensors, tensors, tensors, tensors, undefined_tensors,
steps, 0.001, 0.9, 0.999, 1e-2, 1e-8, false, false,
grad_scale, found_inf);
}
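// Custom autograd functions used (via the bindings below) to test torch.compile
// support for C++ torch::autograd::Function: one returns its input unchanged and
// one returns an aliasing view.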
struct CustomAutogradFnReturnsSelf : public torch::autograd::Function<CustomAutogradFnReturnsSelf> {
static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
return self;
}
static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) {
return {grad_output[0] * 0.5};
}
};
struct CustomAutogradFnAliasing : public torch::autograd::Function<CustomAutogradFnAliasing> {
static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
return self.view_symint(self.sym_sizes());
}
static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) {
return {grad_output[0] * 0.5};
}
};
at::Tensor custom_autograd_fn_returns_self(at::Tensor x) {
return CustomAutogradFnReturnsSelf::apply(x);
}
at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
return CustomAutogradFnAliasing::apply(x);
}
// Here, we're exposing a custom device object that corresponds to our custom backend.
// We do this using pybind: exposing an "extension_name.custom_device()" function in python,
// that's implemented in C++.
// The implementation in this file maps directly to the `PrivateUse1` device type.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("custom_device", &get_custom_device, "get custom device object");
m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
m.def("register_generator_first", ®ister_generator_first, "register generator for custom device firstly");
m.def("register_generator_second", ®ister_generator_second, "register generator for custom device secondly");
m.def("set_custom_device_index", &set_custom_device_index, "set custom device index");
m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl creat method");
m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom abs function was called");
m.def("custom_set_backend_meta", &custom_set_backend_meta, "a fake set tensor BackendMeta function");
m.def("check_backend_meta", &check_backend_meta, "check if BackendMeta serialization correctly");
m.def("custom_serialization_registry", &custom_serialization_registry, "register custom serialization function");
m.def("register_hook", ®ister_hook, "register_hook for privateuse1");
m.def("is_register_hook", &is_register_hook, "is_register_hook for privateuse1");
m.def("default_generator", &default_generator, "default_generator for privateuse1");
m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1");
// Co-opting this file to more easily test torch.compile'ing of custom autograd functions in C++
m.def("custom_autograd_fn_returns_self", &custom_autograd_fn_returns_self);
}
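// Illustrative Python-side usage of this extension (a sketch only; the module
// handle `mod` and the backend rename are assumptions, not defined in this file):
//   torch.utils.rename_privateuse1_backend("foo")
//   x = torch.empty(2, 3, device="foo")   # dispatches to the empty kernel above
//   y = x + x                             # dispatches to custom_add_Tensor
//   mod.custom_add_called()               # now returns True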
TORCH_LIBRARY(_test_funcs, m) {
m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)");
}
TORCH_LIBRARY_IMPL(_test_funcs, AutogradCPU, m) {
m.impl("custom_autograd_fn_aliasing", &custom_autograd_fn_aliasing);
}