#include <torch/csrc/autograd/input_metadata.h>
// TODO: we may be able to move some includes from input_metadata.h to here,
// but it seems that function.h transitively depends on some of them.
namespace torch::autograd {
namespace {
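// C++ nested tensors have ragged per-component sizes that cannot be expressed
// as a single flat SymInt vector, so their shape is captured as the Tensor
// returned by _nested_tensor_size(); every other tensor stores a
// SymIntSmallVec of its (symbolic) sizes.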
MetadataShape compute_variant_shape(const at::Tensor& input) {
if (input.is_nested() && !input.unsafeGetTensorImpl()->is_python_dispatch()) {
auto nested_size = input._nested_tensor_size();
return MetadataShape{std::in_place_type<at::Tensor>, nested_size};
}
return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
}
bool is_python_dispatch(const at::Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->is_python_dispatch();
}
bool is_cpp_nested_tensor(const at::Tensor& tensor) {
return tensor.is_nested() && !is_python_dispatch(tensor);
}
} // namespace
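// Usage sketch (illustrative only): metadata is typically captured from a
// forward input so that the engine can later validate incoming gradients and
// materialize zero gradients of the right shape, e.g.
//   InputMetadata meta(input);
//   at::Tensor zeros = meta.zeros_like(); // same shape and options as input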
InputMetadata::InputMetadata(
const at::TensorOptions& options,
MetadataShape input_shape,
bool is_tensor_subclass,
bool is_nested)
: options_{options},
shape_{std::move(input_shape)},
is_tensor_subclass_{is_tensor_subclass},
is_nested_{is_nested},
was_default_constructed_{false} {
const auto device = options.device();
stream_ = c10::impl::getDeviceGuardImpl(device.type())->getStream(device);
}
InputMetadata::InputMetadata(const at::Tensor& t)
: InputMetadata(
t.options(),
compute_variant_shape(t),
is_python_dispatch(t),
t.is_nested()) {}
at::Tensor InputMetadata::zeros_like() const {
TORCH_CHECK(
!is_nested_, "Zeros is not currently supported for nested tensors.");
return at::zeros_symint(shape_as_dim_vector(), options_);
}
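// Given the gradient computed for input i, return it unchanged if its shape
// matches this metadata, sum it down (via reduce_grad) if this shape is
// merely expandable to the grad's shape, and otherwise raise an error built
// with format_error.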
at::Tensor InputMetadata::maybe_reduce(
const size_t i,
at::Tensor grad,
const std::function<std::string(const std::string&)>& format_error) const {
auto fail = [&]() {
const auto message = incompatible_shape_error_message(i, grad);
TORCH_CHECK(false, format_error(message.str()));
};
// Nested tensors are handled with dedicated, hard-coded logic here, at the
// risk of some code duplication. Note that this path does NOT apply the
// careful size-oblivious logic used below.
if (is_nested_ || is_cpp_nested_tensor() || grad.is_nested() ||
::torch::autograd::is_cpp_nested_tensor(grad)) {
if (!is_same_shape(grad)) {
if (is_expandable_to_shape(grad)) {
return reduce_grad(grad);
} else {
fail();
}
} else {
return grad;
}
}
auto shape = shape_as_dim_vector();
auto desired = grad.sym_sizes();
size_t ndim = shape.size();
size_t target_dim = desired.size();
if (ndim > target_dim) {
fail();
}
bool needs_reduce = false;
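// Compare sizes right-aligned, as in broadcasting. For example, an input of
// shape [1, 3] whose grad arrives with shape [4, 3] was broadcast along dim
// 0, so the grad must be summed back down to [1, 3].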
for (const auto i : c10::irange(ndim)) {
const auto& size = shape[ndim - i - 1];
const auto& target = desired[target_dim - i - 1];
// The conditions here are written carefully so that deferred runtime asserts
// can be inferred when the sizes involve unbacked SymInts.
if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
// NB: we could short circuit this once needs_reduce is true but there's
// no point since the reduction function will guard on this anyway
if (!c10::definitely_true(size.sym_eq(target), __FILE__, __LINE__)) {
needs_reduce = true;
}
} else {
if (!size.sym_eq(target).expect_true(__FILE__, __LINE__)) {
fail();
}
}
}
if (ndim != target_dim) {
needs_reduce = true;
}
if (needs_reduce) {
return reduce_grad(grad);
} else {
return grad;
}
}
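// Shapes only compare equal when the nestedness matches; C++ nested tensors
// compare their per-component size tensors rather than sym_sizes().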
bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
if (!is_nestedness_same(grad)) {
return false;
}
if (is_cpp_nested_tensor()) {
return grad._nested_tensor_size().is_same_size(shape_as_tensor());
}
return grad.sym_sizes().equals(shape_as_dim_vector());
}
bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
if (!maybe_expandable_to(grad)) {
return false;
}
return at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
}
at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
// reduce_grad should only be called if is_expandable_to_shape returns true.
TORCH_INTERNAL_ASSERT(maybe_expandable_to(grad));
return at::sum_to(std::move(grad), shape_as_dim_vector());
}
std::stringstream InputMetadata::incompatible_shape_error_message(
const size_t index,
const at::Tensor& grad) const {
std::stringstream ss{};
ss << "invalid gradient at index " << index << " - got ";
if (::torch::autograd::is_cpp_nested_tensor(grad)) {
ss << grad._nested_tensor_size();
} else {
ss << grad.sym_sizes();
}
ss << " but expected shape compatible with ";
if (is_cpp_nested_tensor()) {
ss << shape_as_tensor();
} else {
ss << shape_as_dim_vector();
}
return ss;
}
bool InputMetadata::is_cpp_nested_tensor() const {
bool ret = std::holds_alternative<at::Tensor>(shape_);
TORCH_INTERNAL_ASSERT(ret == (is_nested_ && !is_tensor_subclass_));
return ret;
}
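// Only valid when the shape is stored as a SymIntSmallVec, i.e. for anything
// other than a C++ nested tensor.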
c10::SymIntArrayRef InputMetadata::shape_as_dim_vector() const {
const auto& dim_shape = std::get<SymIntSmallVec>(shape_);
return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
}
// Danger: not thread safe, caller must protect with lock
SymIntSmallVec& InputMetadata::mutable_shape_as_dim_vector() {
return std::get<SymIntSmallVec>(shape_);
}
bool InputMetadata::is_nestedness_same(const at::Tensor& grad) const {
return (
grad.is_nested() == is_nested_ &&
::torch::autograd::is_cpp_nested_tensor(grad) == is_cpp_nested_tensor());
}
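// Only valid for C++ nested tensors, whose sizes are stored as a Tensor.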
at::Tensor InputMetadata::shape_as_tensor() const {
return std::get<at::Tensor>(shape_);
}
bool InputMetadata::maybe_expandable_to(const at::Tensor& grad) const {
// This is an initial screen, based on nestedness information alone, for
// whether the tensor described by this metadata is expandable to grad. If it
// passes, is_expandable_to_shape checks the actual sizes. We support the
// following 3 types of expansion:
bool grad_is_nested = grad.is_nested();
if (!is_nested_ && !grad_is_nested) {
// Normal case (no NestedTensors are involved)
// (1) plain Tensor -> plain Tensor
return true;
} else {
// (2) python NT -> python NT
// (3) plain Tensor -> python NT
return (
grad_is_nested && is_python_dispatch(grad) &&
(!is_nested_ || is_tensor_subclass_));
}
}
} // namespace torch::autograd