#include <torch/csrc/autograd/saved_variable.h>

#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/Tensor.h>

#include <cstdint>
#include <list>
#include <memory>
#include <sstream>

namespace torch {
namespace autograd {

SavedVariable::SavedVariable(
    const Variable& variable,
    bool is_output,
    bool is_inplace_on_view) {
  if (variable.defined()) {
    // Note [Inference tensor cannot be saved for backward]
    // Invariant:
    //   You can't save an inference tensor for backwards.
    // If an inference tensor was saved for backward in an autograd session and
    // then you reenter inference mode and make an inplace update to the tensor
    // without bumping version_counter, it'll lead to silently wrong results
    // when you run backward() for the previous autograd session. Technically
    // we don't have to check here since it'll fail when querying
    // `current_version` on the inference tensor, but we can give a much better
    // error message here.
    //
    // Note that in the documentation we say "inference tensors cannot
    // participate in autograd", which is more restrictive than this invariant:
    // in practice the check is more permissive and only errors out when an
    // inference tensor is saved for backward. Whether a tensor is saved for
    // backward is determined by the derivative formula and thus varies op by
    // op, so saying "no inference tensor in autograd" is easier for users to
    // understand and follow.
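    //
    // A minimal sketch of what this check rejects (C++ frontend; the op that
    // ends up saving `y` is illustrative, it depends on the derivative
    // formula):
    //
    //   auto x = torch::ones({2, 2}, torch::requires_grad());
    //   at::Tensor y;
    //   {
    //     c10::InferenceMode guard;
    //     y = torch::ones({2, 2});  // y is an inference tensor
    //   }
    //   auto z = x * y;  // mul saves y for backward -> rejected right here
    //   // Workaround suggested below: auto z = x * y.clone();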
    TORCH_CHECK(
        !variable.is_inference(),
        "Inference tensors cannot be saved for backward. To work around "
        "you can make a clone to get a normal tensor and use it in autograd.")
    was_default_constructed_ = false;
    const auto& version_counter = impl::version_counter(variable);
    saved_version_ = version_counter.current_version();
    is_leaf_ = variable.is_leaf();
    is_output_ = is_output;
    is_inplace_on_view_ = is_inplace_on_view;

    if (is_inplace_on_view) {
      TORCH_INTERNAL_ASSERT(!is_leaf_ && is_output);
      weak_grad_fn_ = variable.grad_fn();
    }

    auto maybe_hooks = get_default_hooks();
    if (maybe_hooks) {
      save_metadata(variable);
      set_hooks_and_pack_data(std::move(maybe_hooks), variable);
      return;
    }
    // If the variable is a leaf or is not an output, we can safely save the
    // original variable without running the risk of reference cycles:
    // 1. If the variable is not an output, its grad_fn has already been fully
    //    created and in particular will be a different Node than the one
    //    we are currently constructing (the one that owns this SavedVariable).
    // 2. If the variable is a leaf, it only holds a weak reference to the
    //    grad_accumulator, which cannot create a cycle.
    // In those cases, we save the original variable and don't need further
    // processing.
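    //
    // Illustration (a sketch; the ops are just examples of derivative
    // formulas that save an input vs. an output):
    //   auto x = torch::ones({2}, torch::requires_grad());  // leaf
    //   auto y = x.pow(2);  // pow saves its *input* x, a leaf
    //                       // -> x is stored directly (saved_original_)
    //   auto z = y.exp();   // exp saves its *output* z, a non-leaf output
    //                       // -> fall through below: store a detached copy via
    //                       //    tensor_data() plus the metadata needed to
    //                       //    rebuild z in unpack(), avoiding the cycle
    //                       //    z.grad_fn -> SavedVariable -> z -> z.grad_fn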
    if (!is_output || is_leaf_) {
      saved_original_ = true;
      data_ = variable;
      return;
    }

    save_metadata(variable);

    // Only do this if we actually need to.
    data_ = variable.tensor_data();
  }
}

void SavedVariable::save_metadata(const Variable& data) {
  // Save output number, version counter and fw_grad if needed
  output_nr_ = data.output_nr();
  version_counter_ = impl::version_counter(data);

  if (is_leaf_) {
    grad_accumulator_ = impl::grad_accumulator(data);
    requires_grad_ = data.requires_grad();
  } else if (!is_output_) {
    grad_fn_ = data.grad_fn();
  }

  // TODO(albanD) This needs to be updated when moving to multiple levels
  const auto& fw_grad = data._fw_grad(/* level */ 0);
  if (fw_grad.defined()) {
    fw_grad_ = std::make_shared<ForwardGrad>();
    fw_grad_->set_value(fw_grad, /* level */ 0);
  }
}
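
// Returns the default saved-tensor hooks currently registered with the
// autograd engine, if any. As a descriptive note (not exhaustive): on the
// Python side such defaults are installed by context managers like
// torch.autograd.graph.saved_tensors_hooks and torch.autograd.graph.save_on_cpu.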
std::unique_ptr<SavedVariableHooks> SavedVariable::get_default_hooks() {
  return Engine::get_default_engine().get_default_saved_variable_hooks();
}

void SavedVariable::reset_data() {
  hooks_.reset();
  grad_fn_.reset();
  data_.reset();
}

SavedVariable::SavedVariable(
    const c10::optional<Variable>& variable,
    bool is_output,
    bool is_inplace_on_view)
    : SavedVariable(
          variable.has_value() ? *variable : Variable(),
          is_output,
          is_inplace_on_view) {}

Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
  if (was_default_constructed_) {
    return Variable();
  }

  if (!data_.defined()) {
    TORCH_CHECK(hooks_, ERR_BACKWARD_TWICE);
  }

  // We want grad_fn here to provide the most helpful debug message to the user
  // if versions don't match
  auto grad_fn = is_inplace_on_view_
      ? weak_grad_fn_.lock()
      : !hooks_ ? (saved_original_ ? data_.grad_fn() : nullptr) : grad_fn_;

  if (!is_leaf_ && !grad_fn) {
    TORCH_INTERNAL_ASSERT(saved_for, "No grad_fn for non-leaf saved tensor");
    grad_fn = std::move(saved_for);
  }

  // Only check the version counter in the case without hooks.
  // If the user provides hooks, we can't track versions through the hooks.
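  //
  // Illustrative trigger for the error below (a sketch):
  //   auto x = torch::randn({2}, torch::requires_grad());
  //   auto y = x.exp();    // exp saves its output y for backward
  //   y.mul_(2);           // bumps y's version past saved_version_
  //   y.sum().backward();  // -> "modified by an inplace operation" error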
  if (!hooks_) {
    auto current_version = saved_original_
        ? impl::version_counter(data_).current_version()
        : version_counter_.current_version();

    if (saved_version_ != current_version) {
      std::stringstream message;
      message
          << "one of the variables needed for gradient computation has been "
             "modified by an inplace operation: ["
          << data_.toString() << " ";
      if (data_.is_nested()) {
        message << data_._nested_tensor_size() << "]";
      } else {
        message << data_.sizes() << "]";
      }
      if (grad_fn) {
        message << ", which is output " << output_nr_ << " of "
                << grad_fn->name() << ",";
      }
      message << " is at version " << current_version << "; expected version "
              << saved_version_ << " instead.";
      if (!AnomalyMode::is_enabled()) {
        message << " Hint: enable anomaly detection to find the operation "
                   "that failed to compute its gradient, with "
                   "torch.autograd.set_detect_anomaly(True).";
      } else {
        message << " Hint: the backtrace further above shows the operation "
                   "that failed to compute its gradient. The variable in "
                   "question was changed in there or anywhere later. Good "
                   "luck!";
      }
      TORCH_CHECK(false, message.str());
    }
  }

  // The version counter is correct.
  // Additionally, if we deal with a non-leaf variable, we have its correct
  // grad_fn.
  //
  // If we have the original variable, we simply return it.
  if (!hooks_ && saved_original_) {
    return data_;
  }

  const auto data = hooks_ ? hooks_->call_unpack_hook() : data_;

  // NB: saved views are unpacked as normal Variables (not views) even though
  // they still share the same storage. This works only because we never call
  // in-place functions on unpacked variables.
  Variable var;
  if (grad_fn) {
    var = make_variable(data, Edge(std::move(grad_fn), output_nr_));
  } else {
    var = make_variable(data, requires_grad_);
  }

  impl::set_version_counter(var, version_counter_);

  // If a Variable is a leaf (no grad_fn saved), and it requires_grad, then we
  // should have saved the grad accumulator. Even if the Variable is no longer
  // alive, the accumulator should be kept alive by the references in the
  // graph.
  if (is_leaf_ && requires_grad_) {
    TORCH_INTERNAL_ASSERT(
        !grad_accumulator_.expired(), "No grad accumulator for a saved leaf");
  }
  impl::set_grad_accumulator(var, grad_accumulator_);

  // NB: var here is never a view so there is no need to make anything special
  // for the case where the saved Tensor was a view. This whole argument relies
  // on the fact that the Tensor returned by this function is never
  // modified in-place.
  if (fw_grad_ && !fw_grad_->empty()) {
    // TODO(albanD) This needs to be updated when moving to multiple levels
    auto new_fw_grad = fw_grad_->value(/* level */ 0);
    var._set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
  }

  return var;
}

void SavedVariable::set_hooks_and_pack_data(
    std::unique_ptr<SavedVariableHooks>&& hooks,
    const Variable& data) {
  hooks_ = std::move(hooks);
  at::NoGradGuard guard;
  const auto version = impl::version_counter(data).current_version();
  hooks_->call_pack_hook(saved_original_ ? data.detach() : data);
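  // Schematically, the check below forbids a pack hook that mutates the tensor
  // it receives, e.g. one that effectively runs `t.mul_(2)` on its argument
  // (illustrative only). Any such in-place op bumps the version counter, which
  // is what the comparison against `version` detects.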
  TORCH_CHECK(
      version == impl::version_counter(data).current_version(),
      "A saved tensor pack hook is modifying its input in place. "
      "Tensors provided as input to pack hook can not be modified by "
      "in-place operations as this can lead to unexpected side-effects. "
      "Please open an issue if you need to perform in-place operations on "
      "the input to a pack hook.");
}

void SavedVariable::register_hooks(
    std::unique_ptr<SavedVariableHooks>&& hooks) {
  TORCH_INTERNAL_ASSERT(hooks);
  TORCH_CHECK(
      !hooks_,
      "Calling register_hooks on a saved tensor whose hooks have already been set. "
      "Hint: only one pair of hooks is allowed at a time.");
  if (!data_.defined()) {
    if (!was_default_constructed_) {
      TORCH_CHECK(
          false,
          "Calling register_hooks on a saved tensor after it has been freed. "
          "Saved intermediate values of the graph are freed when you call "
          ".backward() or autograd.grad(). Specify retain_graph=True if you "
          "need to backward through the graph a second time or if you need to "
          "access saved variables after calling backward.");
    } else {
      TORCH_CHECK(
          false,
          "Calling register_hooks on a saved tensor with value None is forbidden");
    }
  }

  // If we did not save the original variable, its metadata was already saved
  // at construction time; when we did save the original, save the metadata now
  // before packing.
  if (saved_original_) {
    save_metadata(data_);
  }
  set_hooks_and_pack_data(std::move(hooks), data_);
  data_.reset();
}
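
// Illustrative trigger for this message (a sketch using the C++ frontend):
//   auto x = torch::ones({2}, torch::requires_grad());
//   auto y = x.exp();      // exp saves its output for backward
//   y.sum().backward();    // frees the saved tensors of the graph
//   y.sum().backward();    // walks through y's grad_fn again -> this error
//   // Fix: pass retain_graph=true to the first call, e.g.
//   //   y.sum().backward({}, /*retain_graph=*/true);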
const char* ERR_BACKWARD_TWICE =
    "Trying to backward through the graph a second time (or directly access saved "
    "tensors after they have already been freed). Saved intermediate values "
    "of the graph are freed when you call .backward() or autograd.grad(). Specify "
    "retain_graph=True if you need to backward through the graph a second time or "
    "if you need to access saved tensors after calling backward.";

} // namespace autograd
} // namespace torch