#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/mm.h>
#endif
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/inductor/inductor_ops.h>
#include <torch/library.h>
#include <ATen/FunctionalTensorWrapper.h>
namespace torch::inductor {
using namespace at;
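// Fused "two matmuls into one buffer" kernel: writes a @ b into `out` via
// mm_out, then accumulates c @ d in place with addmm_, so on return
// out == a @ b + c @ d.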
Tensor _mm_plus_mm_out(
    Tensor& out,
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& d) {
  at::mm_out(out, a, b);
  out.addmm_(c, d);
  return out;
}
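// Wrapper matching the registered schema's argument order (inputs first,
// `out` last); forwards to the out-variant above.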
Tensor _mm_plus_mm(
    const Tensor& a,
    const Tensor& b,
    const Tensor& c,
    const Tensor& d,
    Tensor& out) {
  return _mm_plus_mm_out(out, a, b, c, d);
}
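// Creates a tensor that aliases `self`'s storage at a given byte offset with
// an arbitrary dtype, size, and stride; offset_bytes is converted to an
// element offset using the element size of `dtype`. As the name suggests,
// this is meant for carving tensors out of a single pooled allocation, and
// the pool tensor itself must start at storage offset 0.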
Tensor _alloc_from_pool(
    const Tensor& self,
    int64_t offset_bytes,
    ScalarType dtype,
    IntArrayRef size,
    IntArrayRef stride) {
  TORCH_CHECK(self.storage_offset() == 0);
  // based on alias_with_sizes_and_strides from TensorShape.cpp
  Tensor self_ = at::detail::make_tensor<TensorImpl>(
      // c10::TensorImpl::VIEW,
      Storage(self.storage()),
      self.key_set(),
      caffe2::TypeMeta::fromScalarType(dtype));
  auto* self_tmp_ = self_.unsafeGetTensorImpl();
  self_tmp_->set_storage_offset(
      offset_bytes / static_cast<int64_t>(c10::elementSize(dtype)));
  self_tmp_->set_sizes_and_strides(size, stride);
  return self_;
}
// Similar to as_strided with the following differences
// - offset is added to the existing offset (rather than replacing it)
// - view tracking is disabled similar to unsafe_view
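// Hypothetical example: _reinterpret_tensor(t, {2, 3}, {3, 1}, 4) returns a
// contiguous 2x3 view of t's storage starting 4 elements past t's current
// storage offset (offset_increment is counted in elements, not bytes).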
Tensor _reinterpret_tensor(
    const Tensor& self,
    IntArrayRef size,
    IntArrayRef stride,
    int64_t offset_increment) {
  Tensor self_ = at::detail::make_tensor<TensorImpl>(
      Storage(self.storage()), self.key_set(), self.dtype());
  auto* self_tmp_ = self_.unsafeGetTensorImpl();
  self_tmp_->set_storage_offset(self.storage_offset() + offset_increment);
  self_tmp_->set_sizes_and_strides(size, stride);
  return self_;
}
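// Accumulates `new_grad` into `variable`'s .grad field by delegating to the
// autograd AccumulateGrad logic, which may steal, add to, or copy the
// incoming gradient. Meta-device gradients skip that path and are only
// assigned when no gradient exists yet.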
static void accumulate_grad_(const Tensor& variable, const Tensor& new_grad) {
  at::Tensor& grad = variable.mutable_grad();
  if (new_grad.device() != kMeta) {
    // Do not call into this codepath from the C++ frontend; instead call
    // accumulateGrad directly with num_expected_refs set to 1. Here,
    // num_expected_refs is set to 2 to steal the gradient when this is called
    // from Python.
    torch::autograd::AccumulateGrad::accumulateGrad(
        variable,
        grad,
        new_grad,
        2 /* num_expected_refs */,
        [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });
  } else {
    // No shape checking for `device="meta"`, to work around FSDP in-place
    // mutation.
    if (!grad.defined()) {
      grad = new_grad;
    }
  }
}
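// Registers the ops above under the "inductor" namespace. dispatch(...) binds
// each kernel to the CompositeExplicitAutograd dispatch key, and
// pt2_compliant_tag marks the ops as compliant with the PT2 (torch.compile)
// stack.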
TORCH_LIBRARY_FRAGMENT(inductor, m) {
  m.def(
      "_mm_plus_mm(Tensor a, Tensor b, Tensor c, Tensor d, Tensor(t!) out) -> Tensor(t!)",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, _mm_plus_mm),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "_alloc_from_pool(Tensor self, int offset_bytes, ScalarType dtype, int[] size, int[] stride) -> Tensor",
      _alloc_from_pool,
      {at::Tag::pt2_compliant_tag});
  m.def(
      "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor",
      dispatch(
          c10::DispatchKey::CompositeExplicitAutograd, _reinterpret_tensor),
      {at::Tag::pt2_compliant_tag});
  m.def(
      "accumulate_grad_(Tensor variable, Tensor new_grad) -> ()",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, accumulate_grad_),
      {at::Tag::pt2_compliant_tag});
}
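// Rough usage sketch (assumed call site, not part of this file): once this
// fragment is loaded, the ops are reachable from Python as
// torch.ops.inductor.<name>; from C++ the functions above can be called
// directly, e.g.
//   at::Tensor out = at::empty({a.size(0), b.size(1)}, a.options());
//   torch::inductor::_mm_plus_mm(a, b, c, d, out);  // out = a @ b + c @ d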
} // namespace torch::inductor