1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
|
#pragma once
#include <ATen/ExpandUtils.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/util/variant.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/zeros.h>
#endif
#include <cstdint>
#include <utility>
namespace torch {
namespace autograd {
// Small-vector of symbolic sizes, parameterized with c10::kDimVectorStaticSize
// as its SmallVector size argument. Used for the shape of non-nested tensors.
using SymIntSmallVec = c10::SmallVector<c10::SymInt, c10::kDimVectorStaticSize>;
// Shape as recorded by InputMetadata: either a list of symbolic sizes
// (SymIntSmallVec) or an at::Tensor. InputMetadata stores the tensor
// alternative for nested inputs (the result of
// at::native::get_nested_size_tensor) and the size list otherwise.
using MetadataShape = c10::variant<SymIntSmallVec, at::Tensor>;
/**
* Records TensorOptions, shape of the tensor, whether or not the Python
* dispatch key is set (tensor subclass), and, where applicable, the stream the
* corresponding operation took place on.
*
* If is_valid() is false, then the corresponding input is not used and may be
* an undefined tensor.
*/
struct InputMetadata {
  /// Default construction: every member takes its in-class / default
  /// initializer (see the data members at the bottom of the struct).
  InputMetadata() = default;

  /// Builds metadata from explicit pieces.
  ///
  /// @param options            tensor options to record (dtype/device/layout
  ///                           are later read back from these).
  /// @param input_shape        shape to record; taken by value and moved into
  ///                           the member, so callers may pass an rvalue to
  ///                           avoid a copy.
  /// @param is_tensor_subclass flag returned later by is_tensor_subclass().
  ///
  /// The stream is obtained from the device guard implementation registered
  /// for the device type of `options.device()`.
  InputMetadata(
      const at::TensorOptions options,
      MetadataShape input_shape,
      bool is_tensor_subclass)
      : options_{options},
        shape_{std::move(input_shape)},
        is_tensor_subclass_{is_tensor_subclass} {
    const auto input_device = options.device();
    stream_ = c10::impl::getDeviceGuardImpl(input_device.type())
                  ->getStream(input_device);
  }

  /// Builds metadata from a tensor: records its options, its shape (nested
  /// size tensor for nested tensors, symbolic sizes otherwise) and whether its
  /// TensorImpl reports is_python_dispatch().
  ///
  /// NOTE(review): left non-explicit on purpose so that any existing call
  /// site relying on the implicit Tensor -> InputMetadata conversion keeps
  /// compiling; confirm against callers before tightening.
  InputMetadata(const at::Tensor& t)
      : InputMetadata(
            t.options(),
            compute_variant_shape(t),
            t.unsafeGetTensorImpl()->is_python_dispatch()) {}

  /// Returns a copy of the recorded tensor options.
  const at::TensorOptions options() const {
    return options_;
  }

  /// Returns the dtype stored in the recorded options.
  caffe2::TypeMeta dtype() const {
    return options_.dtype();
  }

  /// Returns the device stored in the recorded options.
  at::Device device() const {
    return options_.device();
  }

  /// Returns the layout stored in the recorded options.
  at::Layout layout() const {
    return options_.layout();
  }

  /// Returns the stream recorded at construction time.
  c10::Stream stream() const {
    return stream_;
  }

  /// Returns the tensor-subclass flag recorded at construction time.
  bool is_tensor_subclass() const {
    return is_tensor_subclass_;
  }

  /// Creates a zero-filled tensor with the recorded shape and options.
  /// Fails the TORCH_CHECK if the recorded shape is a nested-tensor shape.
  at::Tensor zeros_like() const {
    TORCH_CHECK(
        !is_nested_tensor(),
        "Zeros is not currently supported for nested tensors.");
    return at::zeros_symint(shape_as_dim_vector(), options_);
  }

  /// Returns true when `grad` has exactly the recorded shape.
  /// Fails the TORCH_CHECK if exactly one of `grad` / this metadata is nested.
  bool is_same_shape(const at::Tensor& grad) const {
    TORCH_CHECK(
        grad.is_nested() == is_nested_tensor(),
        "Both grad and InputMetadata need to be either nested or non nested tensors.");
    if (grad.is_nested()) {
      // Nested case: compare the nested-size tensors.
      return at::native::get_nested_size_tensor(grad).is_same_size(
          shape_as_tensor());
    }
    return grad.sym_sizes().equals(shape_as_dim_vector());
  }

  /// Returns true when the recorded shape is expandable to the shape of
  /// `grad`; always false for nested tensors.
  /// Fails the TORCH_CHECK if exactly one of `grad` / this metadata is nested.
  bool is_expandable_to_shape(const at::Tensor& grad) const {
    // Currently NestedTensors are not expandable. If this support is added
    // then updates to reduce_grad will be needed.
    TORCH_CHECK(
        grad.is_nested() == is_nested_tensor(),
        "Both grad and InputMetadata need to be either nested or non nested tensors.");
    if (grad.is_nested()) {
      return false;
    }
    return at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
  }

  /// Sums `grad` down to the recorded shape via at::sum_to.
  ///
  /// `grad` is moved from: the caller's tensor must not be relied upon after
  /// this call.
  at::Tensor reduce_grad(at::Tensor& grad) const {
    // Currently reduce_grad is only called if is_expandable_to_shape returns
    // true. For nested tensors that always returns false, so this assert
    // shouldn't fail.
    TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_tensor());
    return at::sum_to(std::move(grad), shape_as_dim_vector());
  }

  /// Builds the "invalid gradient" diagnostic for a gradient whose shape does
  /// not match the recorded one.
  ///
  /// @param index position of the offending gradient, echoed in the message.
  /// @param grad  the offending gradient; its (nested) sizes are printed.
  /// @return a stringstream holding the message, so callers can append to it.
  std::stringstream incompatible_shape_error_message(
      const size_t index,
      const at::Tensor& grad) const {
    std::stringstream ss;
    ss << "invalid gradient at index " << index << " - got ";
    if (grad.is_nested()) {
      ss << at::native::get_nested_size_tensor(grad);
    } else {
      ss << grad.sizes();
    }
    ss << " but expected shape compatible with ";
    if (is_nested_tensor()) {
      ss << shape_as_tensor();
    } else {
      ss << c10::asIntArrayRefSlow(shape_as_dim_vector());
    }
    return ss;
  }

 private:
  /// True when the recorded shape holds the at::Tensor alternative, which is
  /// how nested-tensor shapes are stored.
  bool is_nested_tensor() const {
    return c10::holds_alternative<at::Tensor>(shape_);
  }

  /// Computes the shape variant for `input`: the nested-size tensor for nested
  /// tensors, otherwise a copy of its symbolic sizes.
  ///
  /// Static because it reads no object state and is evaluated in the
  /// delegating constructor's argument list, before any member is initialized.
  static MetadataShape compute_variant_shape(const at::Tensor& input) {
    if (input.is_nested()) {
      auto nested_size = at::native::get_nested_size_tensor(input);
      return MetadataShape{c10::in_place_type<at::Tensor>, nested_size};
    }
    return MetadataShape{c10::in_place_type<SymIntSmallVec>, input.sym_sizes()};
  }

  /// Non-owning view over the recorded symbolic sizes. The view points into
  /// shape_, so it must not outlive this object. Only meaningful when the
  /// recorded shape is not a nested one (c10::get is used on the
  /// SymIntSmallVec alternative).
  c10::SymIntArrayRef shape_as_dim_vector() const {
    const auto& dim_shape = c10::get<SymIntSmallVec>(shape_);
    return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
  }

  /// Returns the recorded nested-size tensor. Only meaningful when the
  /// recorded shape is a nested one (c10::get is used on the at::Tensor
  /// alternative).
  at::Tensor shape_as_tensor() const {
    return c10::get<at::Tensor>(shape_);
  }

  // Declaration order matters: stream_'s default initializer calls device(),
  // which reads options_, so options_ must be declared (and thus initialized)
  // first.
  const at::TensorOptions options_;
  MetadataShape shape_;
  c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device());
  bool is_tensor_subclass_ = false;
};
} // namespace autograd
} // namespace torch
|