#pragma once

#include <ATen/Tensor.h>
#include <c10/util/irange.h>

namespace torch {
namespace autograd {
namespace utils {

// Helper functions to enforce the "Gradient Layout Contract" described in
// torch/csrc/autograd/functions/accumulate_grad.h.
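//
// In short (see accumulate_grad.h for the full description): AccumulateGrad
// tries to stash strided (non-sparse) grads with a memory layout such that
//   (1) if variable.is_non_overlapping_and_dense(), the stashed grad's
//       strides match variable's strides;
//   (2) otherwise, the stashed grad is rowmajor contiguous.
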
// Checks if grad obeys the contract with variable.
inline bool obeys_layout_contract(
    const at::Tensor& grad,
    const at::Tensor& variable) {
  TORCH_INTERNAL_ASSERT(!grad.is_sparse());
  TORCH_INTERNAL_ASSERT(!variable.is_sparse());
  TORCH_INTERNAL_ASSERT(!grad.is_sparse_csr());
  TORCH_INTERNAL_ASSERT(!variable.is_sparse_csr());

  if (variable.is_nested()) {
    // TODO: Nested Tensor does not have an implementation of detach. The
    // current implementation of nested tensor likely does obey the gradient
    // contract and should return true, but this would likely change in the
    // future
    return false;
  } else if (variable.is_non_overlapping_and_dense()) {
    // Only look at stride for dimensions that are not of size 1.
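    // For example (illustrative): a variable with sizes (2, 1, 3) and strides
    // (3, 3, 1) is matched by a grad with strides (3, 1, 1); the stride of the
    // size-1 middle dimension is ignored as long as it is nonzero.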
    const auto& grad_sizes = grad.sizes();
    const auto& grad_strides = grad.strides();
    const auto& variable_strides = variable.strides();
    for (const auto idx : c10::irange(grad_sizes.size())) {
      if (grad_sizes[idx] != 1) {
        if (grad_strides[idx] != variable_strides[idx]) {
          return false;
        }
      } else {
        // This should not be needed, but we don't check whether a Tensor has
        // views before stashing it, and 0-strided Tensors of size 1 can
        // actually be views produced by ops like cat.
        // TODO: Actually detect views in the accumulateGrad function so that
        // this Tensor is not considered at all.
        if (grad_strides[idx] == 0) {
          return false;
        }
      }
    }
    return true;
  } else {
    return grad.is_contiguous(at::MemoryFormat::Contiguous);
  }
}

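// Illustrative example (not from the upstream sources): for a channels-last
// variable of sizes (2, 3, 4, 5), variable.strides() is (60, 1, 15, 3), while a
// rowmajor-contiguous grad has strides (60, 20, 5, 1). The strides differ on
// dimensions of size > 1, so obeys_layout_contract returns false and a caller
// would typically re-lay out the grad with clone_obey_contract below.
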
// Creates a clone of new_grad that obeys the contract with variable.
// The clone should attach to new_grad's history if GradMode::is_enabled().
inline at::Tensor clone_obey_contract(
    const at::Tensor& new_grad,
    const at::Tensor& variable) {
  if (variable.is_non_overlapping_and_dense()) {
    // (1)
    // Does this dicey-looking sequence attach the result to new_grad's
    // history if GradMode::is_enabled()? Yes, and @alband says it should.
    return std::move(new_grad
                         .new_empty_strided(
                             variable.sizes(),
                             variable.strides(),
                             variable.options().memory_format(c10::nullopt))
                         .copy_(new_grad));
  } else {
    // (2)
    return new_grad.clone(at::MemoryFormat::Contiguous);
  }
}
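
// Usage sketch (illustrative, not part of this header's API): a caller that
// stashes gradients, such as AccumulateGrad, can combine the two helpers
// roughly as follows, reusing new_grad when it already obeys the contract and
// cloning it into a conforming layout otherwise:
//
//   at::Tensor stashed = obeys_layout_contract(new_grad, variable)
//       ? new_grad
//       : clone_obey_contract(new_grad, variable);
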
} // namespace utils
} // namespace autograd
} // namespace torch