#include <torch/csrc/autograd/functions/accumulate_grad.h>

#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/variable.h>

#include <cstdint>
#include <stdexcept>
#include <utility>

using at::Tensor;
namespace torch {
namespace autograd {
// AccumulateGrad sets sequence_nr to the max value so it's always called
// ASAP during backwards.
AccumulateGrad::AccumulateGrad(Variable variable_)
    : Node(/*sequence_nr=*/UINT64_MAX), variable(std::move(variable_)) {
  add_input_metadata(variable);
}
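
// A minimal usage sketch (illustration only, not part of this file's logic),
// assuming the usual C++ frontend API: every backward pass that reaches a
// leaf tensor ends in its AccumulateGrad node, and repeated backward calls
// accumulate into the same .grad().
//
//   torch::Tensor x = torch::randn({2, 2}, torch::requires_grad());
//   (x * x).sum().backward();  // AccumulateGrad writes 2 * x into x.grad()
//   (x * x).sum().backward();  // second pass accumulates: x.grad() == 4 * x
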
auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
  check_input_variables("AccumulateGrad", grads, 1, 0);

  if (!grads[0].defined())
    return {};
  if (variable.grad_fn())
    throw std::logic_error(
        "leaf variable has been moved into the graph interior");
  if (!variable.requires_grad())
    return {};

  // std::move(grads[0]) to avoid bumping up the refcount of the incoming
  // gradient.
  at::Tensor new_grad = callHooks(variable, std::move(grads[0]));
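
  // Illustration only (exact hook plumbing varies by version): the hooks run
  // by callHooks above are the gradient hooks attached to the leaf tensor,
  // e.g. via the C++ frontend's Tensor::register_hook:
  //
  //   x.register_hook([](torch::Tensor grad) { return grad * 2; });
  //
  // Each hook may return a replacement gradient (as here) or leave it as is.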

  // Acquire the lock to protect thread safety on `variable`; this ensures
  // that AccumulateGrad does not race on the shared variable when updating
  // the gradient from different threads. We don't ensure thread safety on
  // hooks and rely on the user to provide thread-safe hooks.
  // See Note [Thread Safety on Autograd Node].
  std::lock_guard<std::mutex> lock(mutex_);
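
  // Illustration only, assuming the usual C++ frontend API: the race this
  // lock guards against is two backward passes, on different graphs that
  // share this leaf, running concurrently from different threads:
  //
  //   torch::Tensor x = torch::randn({2, 2}, torch::requires_grad());
  //   std::thread t1([&] { (x * 2).sum().backward(); });
  //   std::thread t2([&] { (x * 3).sum().backward(); });
  //   t1.join();
  //   t2.join();
  //
  // Both threads reach the same AccumulateGrad node for x; the mutex above
  // serializes their updates to x.grad().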

  at::Tensor& grad = variable.mutable_grad();

  // If the function has post hooks (for example, a DDP allreduce hook),
  // call_function in Engine.cpp will temporarily bump the expected refcount
  // by one, hence the addition of !post_hooks().empty() to 'num_expected_refs',
  // on top of the one reference that we're holding.
  // 'num_expected_refs' is used to determine whether we should clone the grad
  // or can steal it instead.
  accumulateGrad(
      variable,
      grad,
      new_grad,
      1 + !post_hooks().empty() /* num_expected_refs */,
      [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });

  return variable_list();
}
} // namespace autograd
} // namespace torch